import copy
from abc import ABC, abstractmethod
from collections import defaultdict
from itertools import chain, product
from typing import List
import gym
import numpy as np
from gym.spaces import Box, Dict, Discrete, Tuple


""" Abstraction of a matrix game, used to pass and store all relevant information about the setting we are working with. This is not an RL environment. """

class MatrixGame(ABC):
    """Abstract single-step matrix game.

    Subclasses supply ``_original_payoff_matrix``: an array whose leading axes
    are indexed by each player's action (in ``list_of_agents`` order) and whose
    last axis selects the player receiving the payoff. Every public accessor
    works on the normalized matrix returned by ``payoff_matrix``.
    """

    # currently for a single step game

    @property
    def payoff_matrix(self):
        """Return the payoff matrix scaled by its maximum entry.

        Recomputed on every access; each access also (re)sets ``self.norm_factor``.
        """
        return self.normalize_payoff_matrix(self._original_payoff_matrix)

    def normalize_payoff_matrix(self, matrix):
        """Divide ``matrix`` by its maximum entry and store that factor.

        If the maximum entry is 0 (e.g. an all-zero game), the factor falls back
        to 1 so the result is the unchanged matrix instead of NaNs from 0 / 0.

        NOTE(review): a matrix whose maximum is negative is divided by a
        negative number, which flips the sign of every payoff.
        """
        norm_factor = np.max(matrix)
        if norm_factor == 0:
            norm_factor = 1
        self.norm_factor = norm_factor
        return matrix / norm_factor

    @property
    @abstractmethod
    def _original_payoff_matrix(self):
        """Raw (unnormalized) payoff array; must be provided by subclasses."""

    @property
    def num_agents(self):
        """Number of players, i.e. the size of the payoff matrix's last axis."""
        return self.payoff_matrix.shape[-1]

    @property
    def list_of_agents(self):
        """Agent names ``player_0`` ... ``player_{num_agents - 1}``."""
        return [f"player_{i}" for i in range(self.num_agents)]

    def _resolve_agent_idx(self, agent_idx):
        """Map an agent index or agent name to a plain ``int`` index.

        Accepts Python ints, NumPy integer scalars, or a name in the format
        produced by ``list_of_agents`` (e.g. ``"player_1"``).
        """
        if isinstance(agent_idx, (int, np.integer)):
            return int(agent_idx)
        assert isinstance(agent_idx, str)  # assume it's the agent's name
        return self.agent_name_to_idx(agent_idx)

    def payoff(self, action_dict, agent_idx):
        """Return one agent's normalized payoff for a joint pure action.

        Args:
            action_dict: maps every name in ``list_of_agents`` to that agent's
                action index.
            agent_idx: the receiving agent, as an integer index or agent name.

        Returns:
            The entry of ``payoff_matrix`` at (joint action..., agent_idx).
        """
        agent_idx = self._resolve_agent_idx(agent_idx)
        # Hoisted so the matrix is not re-normalized once per agent.
        agents = self.list_of_agents
        actions_tuple = tuple(action_dict[name] for name in agents)
        return self.payoff_matrix[actions_tuple + (agent_idx,)]

    # Not currently used
    def mixed_payoff(self, action_dict, agent_idx, leader_idx=0):
        """Return an agent's payoff when the leader plays a mixed strategy.

        ``action_dict[leader]`` holds the leader's weights over its actions
        (a convex combination, per the ``cvx_combo`` name); every other agent
        maps to a pure action index. Only ``leader_idx == 0`` is supported.

        Args:
            action_dict: agent name -> weights (leader) or action index (others).
            agent_idx: the receiving agent, as an integer index or agent name.
            leader_idx: index of the mixing agent; must be 0.
        """
        # TODO account for leader_idx not being 0
        assert leader_idx == 0
        agent_idx = self._resolve_agent_idx(agent_idx)
        agents = self.list_of_agents

        cvx_combo = action_dict[agents[leader_idx]]
        # Contract the leader's action axis (axis 0) with its weights.
        mixed_matrix = np.tensordot(cvx_combo, self.payoff_matrix, 1)

        # Remaining agents' pure actions, in order, skipping the leader.
        actions_tuple = tuple(
            action_dict[name] for i, name in enumerate(agents) if i != leader_idx
        )
        return mixed_matrix[actions_tuple + (agent_idx,)]

    def action_space(self, agent_idx):
        """Return the number of actions available to ``agent_idx`` (index or name)."""
        agent_idx = self._resolve_agent_idx(agent_idx)
        return self.payoff_matrix.shape[agent_idx]

    def agent_name_to_idx(self, agent_name):
        """Parse the integer suffix of an agent name such as ``"player_2"``."""
        return int(
            (agent_name.split("_"))[-1]
        )  # only works with current agent name representation defined in list_of_agents


""" The first two matrix games (MatrixGameOne and MatrixGameTwo) are from "Bi-level Actor-Critic for
    Multi-agent Coordination" (Zhang et al. 2020). The third one (MatrixGameThree) is from
    "On Stackelberg mixed strategies" (Conitzer 2017). """

class MatrixGameOne(MatrixGame):
    """3x3 two-player common-payoff game: both players always get the same reward."""

    @property
    def _original_payoff_matrix(self):
        # Shared reward indexed by (player_0 action, player_1 action).
        shared = np.array(
            [
                [15, 10, 0],
                [10, 10, 0],
                [0, 0, 30],
            ]
        )
        # Copy it onto a trailing player axis -> shape (3, 3, 2).
        return np.stack([shared, shared], axis=-1)

class MatrixGameTwo(MatrixGame):
    """3x3 two-player general-sum game with different rewards per player."""

    @property
    def _original_payoff_matrix(self):
        # Each table is indexed by (player_0 action, player_1 action).
        p0_payoffs = np.array(
            [
                [20, 0, 0],
                [30, 10, 0],
                [0, 0, 5],
            ]
        )
        p1_payoffs = np.array(
            [
                [15, 0, 0],
                [0, 5, 0],
                [0, 0, 10],
            ]
        )
        # Combine along a trailing player axis -> shape (3, 3, 2).
        return np.stack([p0_payoffs, p1_payoffs], axis=-1)

class MatrixGameThree(MatrixGame):
    """2x2 two-player general-sum game."""

    @property
    def _original_payoff_matrix(self):
        # Indexed by (player_0 action, player_1 action).
        p0_payoffs = np.array([[1, 3], [0, 2]])
        p1_payoffs = np.array([[1, 0], [0, 1]])
        # Trailing axis selects the player -> shape (2, 2, 2).
        return np.stack([p0_payoffs, p1_payoffs], axis=-1)

class MatrixGameDiag1(MatrixGame):
    """3x3 two-player game paying (1, 1) only on joint action (0, 0); all else 0."""

    @property
    def _original_payoff_matrix(self):
        # Axes: (player_0 action, player_1 action, receiving player).
        table = np.zeros((3, 3, 2), dtype=int)
        table[0, 0] = 1
        return table

class MatrixGameDiag2(MatrixGame):
    """3x3 two-player game paying (1, 1) only on joint action (1, 1); all else 0."""

    @property
    def _original_payoff_matrix(self):
        # Axes: (player_0 action, player_1 action, receiving player).
        table = np.zeros((3, 3, 2), dtype=int)
        table[1, 1] = 1
        return table

class MatrixGameDiag3(MatrixGame):
    """3x3 two-player game paying (1, 1) only on joint action (2, 2); all else 0."""

    @property
    def _original_payoff_matrix(self):
        # Axes: (player_0 action, player_1 action, receiving player).
        table = np.zeros((3, 3, 2), dtype=int)
        table[2, 2] = 1
        return table

class MatrixDesignGame(MatrixGame):
    """2x2 two-player general-sum game."""

    @property
    def _original_payoff_matrix(self):
        # Indexed by (player_0 action, player_1 action).
        p0_payoffs = np.array([[3, 6], [4, 2]])
        p1_payoffs = np.array([[3, 4], [6, 2]])
        # Trailing axis selects the player -> shape (2, 2, 2).
        return np.stack([p0_payoffs, p1_payoffs], axis=-1)